Libraries required for this analysis

# Knitr figure settings and libraries required for this analysis.
# Library order is kept as-is: later attaches mask earlier ones.
knitr::opts_chunk$set(fig.align="center") 
library(rstanarm)
library(tidyverse)
library(tidybayes)
library(modelr) 
library(ggplot2)
library(magrittr)  
library(emmeans)
library(bayesplot)
library(brms)
library(gganimate)

# Consistent light theme for all ggplot figures below.
theme_set(theme_light())

# Project-local helpers used later, e.g. posterior_draws_plot() and
# expected_diff_in_mean_plot().
source('helper_functions.R')

In our experiment, we used a visualization recommendation algorithm (composed of one search algorithm and one oracle algorithm) to generate visualizations for the user on one of two datasets. We then measured the user’s accuracy on two tasks: Find Extremum and Retrieve Value.

Given a search algorithm (bfs or dfs), an oracle (compassql or dziban), and a dataset (birdstrikes or movies), we would like to predict a user’s chance of answering the Find Extremum and Retrieve Value tasks correctly. In addition, we would like to know whether the choice of search algorithm and oracle has any meaningful impact on a user’s accuracy for these two tasks.

Read in and clean data

# Read per-participant task accuracy (binary correctness outcomes).
accuracy_data <- read.csv('split_by_participant_groups/accuracy.csv')

# Encode the experimental condition columns as factors so downstream models
# treat them as categorical predictors.
factor_cols <- c("oracle", "search", "dataset")
accuracy_data[factor_cols] <- lapply(accuracy_data[factor_cols], as.factor)

# Containers for fitted models and derived posterior summaries built below.
models <- list()

draw_data <- list()

search_differences <- list()
oracle_differences <- list()

# Fixed RNG seed so sampling and posterior draws are reproducible.
seed <- 12

Building a Model for Accuracy Analysis

We derived a very weakly informative prior from our pilot studies

# Bayesian mixed-effects logistic regression: accuracy modeled by the full
# oracle x search x dataset factorial plus task and participant_group main
# effects, with a random intercept per participant (each participant
# contributes multiple observations).
# The `0 + Intercept` form is the documented brms idiom for placing a prior
# on the intercept of the raw (uncentered) predictors rather than the
# internally centered intercept.
model <- brm(
  bf(
    accuracy ~ 0 + Intercept + oracle * search * dataset + task + participant_group + (1 | participant_id)
  ),
  data = accuracy_data,
  # Priors derived from pilot studies (see narrative above): informative on
  # the intercept, weakly informative on all other coefficients.
  prior = c(prior(normal(0.8, .1), class = "b", coef = "Intercept"),
            prior(normal(0, 2.5), class = "b")),
  family = bernoulli(link = "logit"),  # binary correctness outcome
  warmup = 500,
  iter = 3000,   # 2 chains x 2500 post-warmup iterations = 5000 draws
  chains = 2,
  cores = 2, 
  control = list(adapt_delta = 0.9),  # smaller steps to avoid divergences
  seed = seed,
  file = "models/accuracy"  # cache fit to disk; reloaded on re-knit
)

Diagnostics + Model Evaluation

In the summary table, we want to see Rhat values close to 1.0 and Bulk_ESS in the thousands.

summary(model)
##  Family: bernoulli 
##   Links: mu = logit 
## Formula: accuracy ~ 0 + Intercept + oracle * search * dataset + task + participant_group + (1 | participant_id) 
##    Data: accuracy_data (Number of observations: 132) 
## Samples: 2 chains, each with iter = 3000; warmup = 500; thin = 1;
##          total post-warmup samples = 5000
## 
## Group-Level Effects: 
## ~participant_id (Number of levels: 66) 
##               Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sd(Intercept)     1.42      0.93     0.07     3.49 1.00      852     1751
## 
## Population-Level Effects: 
##                                      Estimate Est.Error l-95% CI u-95% CI Rhat
## Intercept                                0.83      0.10     0.64     1.03 1.00
## oracledziban                             2.68      1.17     0.63     5.14 1.00
## searchdfs                                2.03      1.08     0.08     4.31 1.00
## datasetmovies                            0.95      0.96    -0.78     3.03 1.00
## task2.RetrieveValue                      0.72      0.63    -0.46     1.98 1.00
## participant_groupstudent                 0.12      0.76    -1.20     1.79 1.00
## oracledziban:searchdfs                  -2.69      1.51    -5.65     0.29 1.00
## oracledziban:datasetmovies              -0.87      1.52    -3.77     2.18 1.00
## searchdfs:datasetmovies                  1.17      1.61    -1.97     4.38 1.00
## oracledziban:searchdfs:datasetmovies    -1.34      1.87    -5.15     2.36 1.00
##                                      Bulk_ESS Tail_ESS
## Intercept                                9485     3600
## oracledziban                             3025     2928
## searchdfs                                3393     3639
## datasetmovies                            3356     2767
## task2.RetrieveValue                      7854     3716
## participant_groupstudent                 3025     2087
## oracledziban:searchdfs                   3505     3384
## oracledziban:datasetmovies               3435     3863
## searchdfs:datasetmovies                  4827     3938
## oracledziban:searchdfs:datasetmovies     4819     3803
## 
## Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).

Trace plots help us check whether there is evidence of non-convergence for model.

plot(model)

In our pairs plots, we want to make sure we don’t have highly correlated parameters (highly correlated parameters means that our model has difficulty differentiating the effect of such parameters).

# Joint posterior pairs plot for the main-effect coefficients; strong
# pairwise correlations would indicate poorly identified parameters.
pairs(
  model,
  pars = c("b_Intercept",
            "b_datasetmovies",
           "b_oracledziban",
           "b_searchdfs",
           "b_task2.RetrieveValue"),
  fixed = TRUE
)

pp_check(model, type = "dens_overlay", nsamples = 100)

A confusion matrix can be used to check our correct classification rate (a useful measure to see how well our model fits our data).

# Fitted probabilities -> hard 0/1 predictions at a 0.5 threshold. Per the
# brms docs, predict() on a brmsfit returns a matrix whose first column is
# the posterior mean Estimate.
pred <- predict(model, type = "response")
pred <- if_else(pred[,1] > 0.5, 1, 0)
# Rows: predicted class; columns: observed accuracy. The correct
# classification rate is the diagonal total over all observations.
confusion_matrix <- table(pred, pull(accuracy_data, accuracy)) 
confusion_matrix
##     
## pred   0   1
##    1  10 122

Visualization of parameter effects via draws from our model posterior. The thicker, shorter line represents the 50% credible interval, while the thinner, longer line represents the 95% credible interval.

# Posterior draws of expected accuracy per condition on the probability
# scale, marginalizing out the participant random effect (re_formula = NA).
# NOTE(review): add_fitted_draws() is deprecated in newer tidybayes in favor
# of add_epred_draws() -- confirm the installed version before updating.
draw_data <- accuracy_data %>%
  add_fitted_draws(model, seed = seed, re_formula = NA, scale = "response") %>%
  group_by(search, oracle, dataset, task, .draw)

# Prettify condition labels for plotting. gsub() on a factor column coerces
# it to character, which is acceptable for display-only data.
plot_data <- draw_data
plot_data$oracle<- gsub('compassql', 'CompassQL', plot_data$oracle)
plot_data$oracle<- gsub('dziban', 'Dziban', plot_data$oracle)
plot_data$search<- gsub('bfs', 'BFS', plot_data$search)
plot_data$search<- gsub('dfs', 'DFS', plot_data$search)

# Combined condition label, e.g. "CompassQL + BFS".
plot_data$condition <- paste(plot_data$oracle, plot_data$search, sep=" + ")

# posterior_draws_plot() is defined in helper_functions.R (not shown here).
draw_plot <- posterior_draws_plot(plot_data, "dataset", TRUE, "Predicted accuracy (p_correct)", "Oracle/Search Combination")

draw_plot

Since the credible intervals on our plot overlap, we can use mean_qi to get the numeric boundaries for the different intervals.

# Numeric 95% and 50% quantile intervals of posterior mean accuracy for each
# search/oracle/dataset/task cell (the regroup drops the .draw grouping).
fit_info <-  draw_data %>% group_by(search, oracle, dataset, task) %>% mean_qi(.value, .width = c(.95, .5))
fit_info
## # A tibble: 32 x 10
## # Groups:   search, oracle, dataset [8]
##    search oracle  dataset  task     .value .lower .upper .width .point .interval
##    <fct>  <fct>   <fct>    <chr>     <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>    
##  1 bfs    compas… birdstr… 1. Find…  0.698  0.440  0.917   0.95 mean   qi       
##  2 bfs    compas… birdstr… 2. Retr…  0.809  0.538  0.966   0.95 mean   qi       
##  3 bfs    compas… movies   1. Find…  0.827  0.523  0.986   0.95 mean   qi       
##  4 bfs    compas… movies   2. Retr…  0.897  0.656  0.994   0.95 mean   qi       
##  5 bfs    dziban  birdstr… 1. Find…  0.953  0.806  0.998   0.95 mean   qi       
##  6 bfs    dziban  birdstr… 2. Retr…  0.973  0.876  0.999   0.95 mean   qi       
##  7 bfs    dziban  movies   1. Find…  0.949  0.780  0.999   0.95 mean   qi       
##  8 bfs    dziban  movies   2. Retr…  0.973  0.873  1.00    0.95 mean   qi       
##  9 dfs    compas… birdstr… 1. Find…  0.920  0.711  0.996   0.95 mean   qi       
## 10 dfs    compas… birdstr… 2. Retr…  0.955  0.821  0.998   0.95 mean   qi       
## # … with 22 more rows
## Saving 7 x 5 in image

Differences Between Conditions

# Posterior draws of expected accuracy for every observed condition, again
# marginalizing out the participant random effect (re_formula = NA).
predictive_data  <- accuracy_data %>%
    add_fitted_draws(model, seed = seed, re_formula = NA, scale = "response")

Differences in search algorithms:

# Posterior difference in mean accuracy between search algorithms
# (bfs - dfs, per the interval labels below), faceted by dataset.
# expected_diff_in_mean_plot() is defined in helper_functions.R.
search_differences <- expected_diff_in_mean_plot(predictive_data, "search", "Difference in Mean Accuracy (p_correct)",  "Task", "dataset")
## `summarise()` regrouping output by 'search', 'task', 'dataset' (override with `.groups` argument)
search_differences$plot

We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.

search_differences$intervals
## # A tibble: 8 x 9
## # Groups:   search, dataset [2]
##   search  dataset   task     difference  .lower   .upper .width .point .interval
##   <chr>   <fct>     <chr>         <dbl>   <dbl>    <dbl>  <dbl> <chr>  <chr>    
## 1 bfs - … birdstri… 1. Find…    -0.0886 -0.224   0.0698    0.95 mean   qi       
## 2 bfs - … birdstri… 2. Retr…    -0.0609 -0.192   0.0338    0.95 mean   qi       
## 3 bfs - … movies    1. Find…    -0.0510 -0.208   0.0888    0.95 mean   qi       
## 4 bfs - … movies    2. Retr…    -0.0313 -0.146   0.0540    0.95 mean   qi       
## 5 bfs - … birdstri… 1. Find…    -0.0886 -0.136  -0.0477    0.5  mean   qi       
## 6 bfs - … birdstri… 2. Retr…    -0.0609 -0.0909 -0.0245    0.5  mean   qi       
## 7 bfs - … movies    1. Find…    -0.0510 -0.0929 -0.00687   0.5  mean   qi       
## 8 bfs - … movies    2. Retr…    -0.0313 -0.0543 -0.00349   0.5  mean   qi

Differences in oracle:

# Posterior difference in mean accuracy between oracles (dziban - compassql,
# per the interval labels below), faceted by dataset.
oracle_differences <- expected_diff_in_mean_plot(predictive_data, "oracle", "Difference in Mean Accuracy (p_correct)",  "Task", "dataset")
## `summarise()` regrouping output by 'oracle', 'task', 'dataset' (override with `.groups` argument)
oracle_differences$plot

We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.

oracle_differences$intervals
## # A tibble: 8 x 9
## # Groups:   oracle, dataset [2]
##   oracle     dataset  task     difference  .lower .upper .width .point .interval
##   <chr>      <fct>    <chr>         <dbl>   <dbl>  <dbl>  <dbl> <chr>  <chr>    
## 1 dziban - … birdstr… 1. Find…     0.121  -0.0291 0.272    0.95 mean   qi       
## 2 dziban - … birdstr… 2. Retr…     0.0789 -0.0162 0.215    0.95 mean   qi       
## 3 dziban - … movies   1. Find…     0.0183 -0.135  0.170    0.95 mean   qi       
## 4 dziban - … movies   2. Retr…     0.0132 -0.0801 0.121    0.95 mean   qi       
## 5 dziban - … birdstr… 1. Find…     0.121   0.0762 0.166    0.5  mean   qi       
## 6 dziban - … birdstr… 2. Retr…     0.0789  0.0390 0.110    0.5  mean   qi       
## 7 dziban - … movies   1. Find…     0.0183 -0.0212 0.0592   0.5  mean   qi       
## 8 dziban - … movies   2. Retr…     0.0132 -0.0102 0.0348   0.5  mean   qi

Differences in participant group (student vs professional):

# Difference in accuracy between participant groups (student vs
# professional), collapsed across datasets (final NULL = no dataset facet).
participant_group_differences <- expected_diff_in_mean_plot(predictive_data, "participant_group", "Difference in Mean Accuracy (p_correct)", "Task", NULL)
## `summarise()` regrouping output by 'participant_group', 'task', 'dataset' (override with `.groups` argument)
participant_group_differences$plot

We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.

participant_group_differences$intervals
## # A tibble: 4 x 8
## # Groups:   participant_group [1]
##   participant_group  task      difference  .lower .upper .width .point .interval
##   <chr>              <chr>          <dbl>   <dbl>  <dbl>  <dbl> <chr>  <chr>    
## 1 student - profess… 1. Find …  -0.000481 -0.120  0.113    0.95 mean   qi       
## 2 student - profess… 2. Retri…   0.00170  -0.0758 0.0926   0.95 mean   qi       
## 3 student - profess… 1. Find …  -0.000481 -0.0345 0.0339   0.5  mean   qi       
## 4 student - profess… 2. Retri…   0.00170  -0.0191 0.0201   0.5  mean   qi
# Same participant-group difference, now additionally faceted by dataset.
participant_group_differences_dataset <- expected_diff_in_mean_plot(predictive_data, "participant_group", "Difference in Mean Accuracy (p_correct)", "Task", "dataset")
## `summarise()` regrouping output by 'participant_group', 'task', 'dataset' (override with `.groups` argument)
participant_group_differences_dataset$plot